import wandb
import torch
import random
from tqdm import tqdm
from collections import defaultdict

from utils import utils
from utils.utils import assert_mean_zero_with_mask, remove_mean_with_mask,\
    assert_correctly_masked, sample_center_gravity_zero_gaussian_with_mask
import numpy as np
import qm9.utils as qm9utils
import qm9.visualizer as vis
from qm9.analyze import analyze_stability_for_molecules

from traintest import losses
from traintest.sampling import sample_chain, sample, sample_sweep_conditional

def prepare_batch_data(args, data, device, dtype, partition='Train'):
    """Move one batch to ``device``/``dtype`` and build the model inputs.

    Args:
        args: Run configuration. Reads ``include_charges``, ``dataset``,
            ``mask_ratio``, ``ood_element_size``, ``augment_noise`` and
            ``data_augmentation``.
        data: Batch mapping. Always reads ``'positions'``, ``'atom_mask'``,
            ``'edge_mask'`` and ``'one_hot'``; reads ``'charges'`` when
            ``args.include_charges`` is set, ``'scaffold_mask'`` for the
            ``'qm9_scaffold'`` dataset, and ``'num_atoms'`` (optional) for
            random masking.
        device: Target device for every returned tensor.
        dtype: Target dtype for every returned tensor.
        partition: Random rotation is only applied when this is ``'Train'``.

    Returns:
        Tuple ``(x, h, node_mask, edge_mask, scaffold_mask)`` where ``h`` is
        ``{'categorical': one_hot, 'integer': charges}`` and ``node_mask`` is
        the atom mask with a trailing singleton dimension added.
    """
    x = data['positions'].to(device, dtype)
    node_mask = data['atom_mask'].to(device, dtype).unsqueeze(2)
    edge_mask = data['edge_mask'].to(device, dtype)
    one_hot = data['one_hot'].to(device, dtype)
    # An empty tensor stands in for charges when they are not modelled.
    charges = (data['charges'] if args.include_charges else torch.zeros(0)).to(device, dtype)

    if args.dataset == 'qm9_scaffold':
        scaffold_mask = data['scaffold_mask'].to(device, dtype)
    elif args.dataset == 'qm9_scaffold':
        # NOTE(review): this condition duplicates the one above, so the branch
        # is unreachable. It presumably targets a different dataset name that
        # provides 'ring_mask' -- confirm the intended name and fix the test.
        scaffold_mask = data['ring_mask'].to(device, dtype)
    else:
        # Random masking: keep a random (1 - mask_ratio) share of each
        # molecule's atoms as the scaffold.
        bs, ns, _ = node_mask.size()
        scaffold_mask = torch.zeros([bs, ns]).to(device, dtype)
        try:
            masked_length = data['num_atoms']
        except KeyError:
            # No explicit atom counts: derive them from the non-zero entries
            # of each row of the atom mask.
            masked_length = (data['atom_mask'] != 0).sum(dim=1).tolist()
        for i in range(bs):
            n_keep = int((1 - args.mask_ratio) * masked_length[i])
            ones_indices = random.sample(range(masked_length[i]), n_keep)
            scaffold_mask[i, ones_indices] = 1

    # Resize the one-hot encoding to args.ood_element_size by padding zeros
    # onto the end of its last dimension.
    if args.ood_element_size != one_hot.shape[-1]:
        one_hot = torch.nn.functional.pad(one_hot, (0, args.ood_element_size - one_hot.shape[-1]), value=0)
    x = remove_mean_with_mask(x, node_mask)

    if args.augment_noise > 0:
        # Add noise eps ~ N(0, augment_noise) around points.
        eps = sample_center_gravity_zero_gaussian_with_mask(x.size(), x.device, node_mask)
        x = x + eps * args.augment_noise

    # Re-centre after the optional noise so the zero-mean assertion below holds.
    x = remove_mean_with_mask(x, node_mask)
    if args.data_augmentation and partition == 'Train':
        x = utils.random_rotation(x).detach()

    check_mask_correct([x, one_hot, charges], node_mask)
    assert_mean_zero_with_mask(x, node_mask)

    h = {'categorical': one_hot, 'integer': charges}
    return x, h, node_mask, edge_mask, scaffold_mask

def _write_counts(f, counts):
    """Write ``counts`` to ``f`` as ``{key: value, key: value, }``."""
    f.write('{')
    for key, value in counts.items():
        f.write(f'{key}: {value}, ')
    f.write('}')


def print_nodes_distribution(args, loaders, device, dtype):
    """Append a histogram of molecule sizes to ``a.txt``.

    For every batch of every loader the number of unmasked nodes per sample
    is counted (sum of ``node_mask`` over the node dimension). Each loader
    gets its own section with its sample count and its own histogram; a
    final section holds the totals over all loaders, which are also printed.

    Args:
        args: Run configuration forwarded to ``prepare_batch_data``.
        loaders: Iterable of data loaders to scan.
        device: Device forwarded to ``prepare_batch_data``.
        dtype: Dtype forwarded to ``prepare_batch_data``.
    """
    with open('a.txt', 'a+') as f:
        attribute_counts = defaultdict(int)
        size = 0
        for loader in loaders:
            local_attribute_counts = defaultdict(int)
            local_size = 0
            for data in tqdm(loader, ncols=80):
                # prepare_batch_data returns five values; only the positions
                # and the node mask are needed here.
                x, _, node_mask, _, _ = prepare_batch_data(args, data, device, dtype)
                attributes = node_mask.squeeze(2).sum(1).long()
                for attribute in attributes.cpu().numpy().tolist():
                    attribute_counts[attribute] += 1
                    local_attribute_counts[attribute] += 1
                local_size += x.size(0)
            f.write(f'\nLoader: {local_size}\n')
            size += local_size
            # Per-loader section reports this loader's histogram, not the
            # running total (which is written once at the end).
            _write_counts(f, local_attribute_counts)
        print(attribute_counts)
        f.write(f'\nTotal {size}\n')
        _write_counts(f, attribute_counts)

def train_epoch(args, loader, epoch, model, model_dp, model_ema, ema, device, dtype, property_norms, optim,
                nodes_dist, gradnorm_queue, dataset_info, prop_dist):
    """Train for one pass over ``loader`` and log the NLL to wandb.

    Per batch: prepare inputs, build the optional conditioning context,
    compute ``nll + args.ode_regularization * reg_term``, back-propagate,
    optionally clip gradients, step the optimizer and optionally update the
    EMA copy of the model. The batch NLL is logged after every step and the
    mean NLL over the processed batches is logged once at the end.

    ``dataset_info`` and ``prop_dist`` are accepted but not used here.
    """
    model_dp.train()
    model.train()

    batch_nlls = []
    n_iterations = len(loader)
    progress = tqdm(loader, ncols=80)

    for step, batch in enumerate(progress):
        x, h, node_mask, edge_mask, scaffold_mask = prepare_batch_data(args, batch, device, dtype)

        context = None
        if len(args.conditioning) > 0:
            context = qm9utils.prepare_context(args.conditioning, batch, property_norms).to(device, dtype)
            assert_correctly_masked(context, node_mask)

        optim.zero_grad()

        # Forward pass: loss terms are computed with the data-parallel model.
        nll, reg_term, mean_abs_z = losses.compute_loss_and_nll(
            args, model_dp, nodes_dist, x, h, node_mask, edge_mask, context, scaffold_mask)
        loss = nll + args.ode_regularization * reg_term
        loss.backward()

        # Clipping must happen between backward() and step().
        grad_norm = utils.gradient_clipping(model, gradnorm_queue) if args.clip_grad else 0.

        optim.step()

        # Update EMA if enabled.
        if args.ema_decay > 0:
            ema.update_model_average(model_ema, model)

        nll_value = nll.item()
        progress.set_description(
            f"\rE: {epoch}, i: {step}/{n_iterations}, "
            f"Loss {loss.item():.2f}, NLL: {nll_value:.2f}, "
            f"GradNorm: {grad_norm:.2f}")
        batch_nlls.append(nll_value)
        wandb.log({"Batch NLL": nll_value}, commit=True)

        # Debug switch: stop after the first batch.
        if args.break_train_epoch:
            break

    wandb.log({"Train Epoch NLL": np.mean(batch_nlls)}, commit=False)


def check_mask_correct(variables, node_mask):
    """Run ``assert_correctly_masked`` on every non-empty entry of ``variables``.

    Entries whose ``len`` is zero (for example the empty charges placeholder)
    are skipped.
    """
    for candidate in variables:
        if len(candidate) == 0:
            continue
        assert_correctly_masked(candidate, node_mask)


def test(args, loader, epoch, eval_model, device, dtype, property_norms, nodes_dist, partition='Test'):
    """Evaluate ``eval_model`` on ``loader`` and return the mean NLL.

    The model is put into eval mode and every batch is processed under
    ``torch.no_grad()``. Each batch NLL is weighted by its batch size, so the
    returned value is the per-sample average over the whole loader.
    ``partition`` is forwarded to ``prepare_batch_data`` and shown in the
    progress bar.
    """
    eval_model.eval()

    weighted_nll = 0
    seen = 0
    n_iterations = len(loader)

    with torch.no_grad():
        progress = tqdm(loader, ncols=80)
        for step, batch in enumerate(progress):
            x, h, node_mask, edge_mask, scaffold_mask = prepare_batch_data(
                args, batch, device, dtype, partition)
            batch_size = x.size(0)

            context = None
            if len(args.conditioning) > 0:
                context = qm9utils.prepare_context(args.conditioning, batch, property_norms).to(device, dtype)
                assert_correctly_masked(context, node_mask)

            # Only the NLL is needed for evaluation; the other outputs are ignored.
            nll, _, _ = losses.compute_loss_and_nll(
                args, eval_model, nodes_dist, x, h, node_mask, edge_mask, context, scaffold_mask)

            weighted_nll += nll.item() * batch_size
            seen += batch_size
            progress.set_description(
                f"\r {partition} \t e: {epoch}, i: {step}/{n_iterations}, "
                f"NLL: {weighted_nll/seen:.2f}")

    return weighted_nll / seen


def save_and_sample_chain(model, args, device, dataset_info, prop_dist,
                          epoch=0, id_from=0, batch_id=''):
    """Sample one chain with ``sample_chain`` and save it via ``vis.save_xyz_file``.

    Output goes under ``outputs/<exp_name>/epoch_<epoch>_<batch_id>/chain/``
    with the file name prefix ``chain``.

    Returns:
        The ``(one_hot, charges, x)`` tuple produced by ``sample_chain``.
    """
    sampled = sample_chain(
        args=args,
        device=device,
        flow=model,
        n_tries=1,
        dataset_info=dataset_info,
        prop_dist=prop_dist,
    )
    one_hot, charges, x = sampled

    target_dir = f'outputs/{args.exp_name}/epoch_{epoch}_{batch_id}/chain/'
    vis.save_xyz_file(target_dir, one_hot, charges, x, dataset_info, id_from, name='chain')

    return one_hot, charges, x


def sample_different_sizes_and_save(model, nodes_dist, args, device, dataset_info, prop_dist,
                                    n_samples=5, epoch=0, batch_size=100, batch_id=''):
    """Sample molecules of varying size in batches and save them via ``vis.save_xyz_file``.

    The effective batch size is ``min(batch_size, n_samples)`` and
    ``int(n_samples / batch_size)`` full batches are generated, so any
    remainder that does not fill a whole batch is not sampled. Node counts
    for each batch are drawn from ``nodes_dist``. Output goes under
    ``outputs/<exp_name>/epoch_<epoch>_<batch_id>/`` with the file name
    prefix ``molecule``.

    Does nothing when the effective batch size is not positive (for example
    ``n_samples=0``); previously this raised ``ZeroDivisionError``.
    """
    batch_size = min(batch_size, n_samples)
    if batch_size <= 0:
        # Nothing to generate, and dividing by this batch size would fail.
        return

    # Loop-invariant output directory.
    out_dir = f'outputs/{args.exp_name}/epoch_{epoch}_{batch_id}/'
    for counter in range(int(n_samples / batch_size)):
        nodesxsample = nodes_dist.sample(batch_size)
        one_hot, charges, x, _ = sample(args, device, model, prop_dist=prop_dist,
                                        nodesxsample=nodesxsample,
                                        dataset_info=dataset_info)
        # NOTE(review): the slice omits the last sample of the batch from the
        # printout -- confirm whether that is intended.
        print(f"Generated molecule: Positions {x[:-1, :, :]}")
        # The id offset keeps file ids unique across batches.
        vis.save_xyz_file(out_dir, one_hot, charges, x, dataset_info,
                          batch_size * counter, name='molecule')


def analyze_and_save(epoch, model_sample, nodes_dist, args, device, dataset_info, prop_dist,
                     n_samples=1000, batch_size=100):
    """Sample ``n_samples`` molecules, analyse them and log the results to wandb.

    Samples are generated in batches of ``min(batch_size, n_samples)``;
    ``n_samples`` must be a multiple of that batch size. The concatenated
    one-hot encodings, positions and node masks are handed to
    ``analyze_stability_for_molecules``. Its first result is logged as-is;
    when the second result is not ``None`` its first element supplies the
    'Validity', 'Uniqueness' and 'Novelty' values.

    Returns:
        The validity dictionary returned by ``analyze_stability_for_molecules``.
    """
    print(f'Analyzing molecule stability at epoch {epoch}...')
    batch_size = min(batch_size, n_samples)
    assert n_samples % batch_size == 0

    one_hot_parts, position_parts, mask_parts = [], [], []
    for _ in range(int(n_samples/batch_size)):
        nodesxsample = nodes_dist.sample(batch_size)
        one_hot, charges, x, node_mask = sample(args, device, model_sample, dataset_info, prop_dist,
                                                nodesxsample=nodesxsample)
        # Move results off the device so batches can accumulate on the CPU.
        one_hot_parts.append(one_hot.detach().cpu())
        position_parts.append(x.detach().cpu())
        mask_parts.append(node_mask.detach().cpu())

    molecules = {
        'one_hot': torch.cat(one_hot_parts, dim=0),
        'x': torch.cat(position_parts, dim=0),
        'node_mask': torch.cat(mask_parts, dim=0),
    }
    validity_dict, rdkit_tuple = analyze_stability_for_molecules(molecules, dataset_info)

    wandb.log(validity_dict)
    if rdkit_tuple is not None:
        rdkit_metrics = rdkit_tuple[0]
        wandb.log({'Validity': rdkit_metrics[0], 'Uniqueness': rdkit_metrics[1], 'Novelty': rdkit_metrics[2]})
    return validity_dict


def save_and_sample_conditional(args, device, model, data, prop_dist, dataset_info, epoch=0, id_from=0):
    """Run ``sample_sweep_conditional`` and save the result via ``vis.save_xyz_file``.

    Output goes under ``outputs/<exp_name>/epoch_<epoch>/conditional/`` with
    the file name prefix ``conditional``; the node mask from the sampler is
    forwarded to the writer.

    Returns:
        The ``(one_hot, charges, x)`` part of the sampler output.
    """
    sweep = sample_sweep_conditional(args, device, model, data, dataset_info, prop_dist)
    one_hot, charges, x, node_mask = sweep

    target_dir = 'outputs/%s/epoch_%d/conditional/' % (args.exp_name, epoch)
    vis.save_xyz_file(target_dir, one_hot, charges, x, dataset_info,
                      id_from, name='conditional', node_mask=node_mask)

    return one_hot, charges, x
